import torch
import torch.utils.data
import numpy as np
from torch import nn, optim
from torch.nn import functional as F

from learned_sigma import get_distribution_class, softclip
from distributions import AttrDict


class VAE(nn.Module):
    """Fully connected VAE whose decoder variance comes from a pluggable sigma decoder.

    The encoder flattens its input to 784 features and produces the mean and
    log-variance of the latent code; the decoder maps a latent code back to
    784 sigmoid activations. The sigma decoder selected by ``args.sigma_mode``
    contributes any additional decoder outputs (e.g. ``sigma``).
    """

    def __init__(self, device='cuda', img_channels=3, args=None):
        """Store the configuration and build the sub-networks.

        Args:
            device: Device on which ``sample`` places the latent noise.
            img_channels: Stored on the instance; not read by this class.
            args: Configuration object. ``batch_size``, ``z_dim``,
                ``distribution``, ``sigma_mode``, ``sigma`` and ``learn_beta``
                are read from it, so it must not be ``None``.
        """
        super().__init__()
        self.batch_size = args.batch_size
        self.device = device
        self.z_dim = args.z_dim
        self.img_channels = img_channels
        self.args = args
        # get_distribution_class returns an indexable result; only its first
        # element (the distribution class) is used here.
        self.distr = get_distribution_class(self.args.distribution, self.args.sigma_mode)[0]

        self.sigma_decoder = self.build_sigma_decoder(args.sigma, args.learn_beta, args.sigma_mode)

        self.build_network()

    def build_network(self):
        """Create the encoder (784 -> 400 -> 2 x z_dim) and decoder (z_dim -> 400 -> 784) layers."""
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, self.z_dim)
        self.fc22 = nn.Linear(400, self.z_dim)
        self.fc3 = nn.Linear(self.z_dim, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent code for ``x`` flattened to ``(-1, 784)``."""
        h1 = F.relu(self.fc1(x.view(-1, 784)))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` standard normal noise."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, x=None):
        """Decode ``z`` into the mean outputs merged with the sigma decoder's outputs.

        Args:
            z: Latent codes.
            x: Optional target batch, forwarded to the sigma decoder.
        """
        outputs = self.decode_mean(z)
        outputs.update(self.sigma_decoder(z, x, outputs.safe.mu))
        return outputs

    def decode_mean(self, z):
        """Return an ``AttrDict`` whose ``mu`` and ``mle`` are both the sigmoid decoder output."""
        h3 = F.relu(self.fc3(z))
        mean = torch.sigmoid(self.fc4(h3))
        return AttrDict(mu=mean, mle=mean)

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(outputs, mu, logvar)`` where ``outputs`` is the result of
            ``decode`` and ``mu``/``logvar`` come from ``encode``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, x), mu, logvar

    def sample(self, n=None):
        """Decode ``n`` standard-normal latent codes and return the ``mle`` output.

        A falsy ``n`` (``None`` or ``0``) falls back to ``self.batch_size``.
        """
        n = n or self.batch_size
        sample = torch.randn(n, self.z_dim).to(self.device)
        sample = self.decode(sample).mle
        return sample

    def build_sigma_decoder(self, sigma, learn_beta, sigma_mode):
        """Instantiate the sigma decoder matching ``sigma_mode``.

        Unrecognised modes fall back to ``DummySigmaDecoder``.
        """
        if sigma_mode == 'scalar_fixed':
            return ConstantSigmaDecoder(sigma, learn_beta)
        # Exact names are matched before the substring test below; otherwise
        # 'posthoc_optimal' would be captured by `'optimal' in sigma_mode`
        # and PostHocOptimalSigmaDecoder could never be selected.
        if sigma_mode == 'posthoc_optimal':
            return PostHocOptimalSigmaDecoder(sigma, learn_beta)
        if sigma_mode == 'posthoc':
            return PostHocSigmaDecoder(sigma, learn_beta)
        if sigma_mode == 'scalar':
            return ScalarSigmaDecoder(sigma, learn_beta)
        if 'optimal' in sigma_mode:
            return OptimalSigmaDecoder(sigma, learn_beta, sigma_mode)

        return DummySigmaDecoder()

    def reconstruction_loss(self, outputs, x, mu, logvar, phase):
        """Compute the summed reconstruction loss for the configured distribution.

        Args:
            outputs: Decoder outputs as returned by ``decode``.
            x: Target batch; reshaped to the reconstruction's shape when the
                element counts match (gaussian and bernoulli branches).
            mu, logvar, phase: Accepted for interface symmetry; not read here.

        Returns:
            Tuple ``(losses, rec)``: ``losses`` is an ``AttrDict`` that holds
            ``sigma_est_rec`` when the decoder produced ``sigma_est``, and
            ``rec`` is the reconstruction loss summed over all elements.

        Raises:
            ValueError: If ``args.distribution`` matches none of the branches.
        """
        losses = AttrDict()
        distribution = self.args.distribution

        if 'gaussian' in distribution:
            recon_mu, log_sigma = outputs.mu, outputs.sigma
            # softclip comes from learned_sigma; presumably it keeps log-sigma
            # from dropping far below -6 -- confirm against that module.
            log_sigma = softclip(log_sigma, -6)

            if x.numel() == recon_mu.numel():
                # Attempt to normalize the shape
                x = x.view(recon_mu.shape)

            rec = self.distr(recon_mu, log_sigma).nll(x)

            if 'sigma_est' in outputs:
                # Post-hoc decoders: fit the estimated sigma against a
                # detached mean so this term does not update the mean.
                logsigma_est = outputs.sigma_est
                losses.sigma_est_rec = self.distr(recon_mu.detach(), logsigma_est).nll(x).sum()

        elif distribution == 'bernoulli':
            if x.numel() == outputs.mu.numel():
                # Attempt to normalize the shape
                x = x.view(outputs.mu.shape)

            rec = F.binary_cross_entropy(outputs.mu, x, reduction='sum')
        elif 'categorical' in distribution:
            rec = self.distr(log_p=outputs.log_prob).nll(x * 2 - 1)
        elif distribution == 'beta':
            rec = outputs.d.nll(x)
        else:
            # Previously this path failed later with an unbound `rec`.
            raise ValueError('Unsupported distribution: {!r}'.format(distribution))

        if torch.isnan(rec).any() or torch.isinf(rec).any():
            # NOTE(review): debugging hook kept as-is; it opens an interactive
            # debugger when the loss is non-finite.
            import pdb
            pdb.set_trace()
        rec = rec.sum()

        return losses, rec

    def loss_function(self, recon_x, x, mu, logvar, phase='train'):
        """Return an ``AttrDict`` with ``rec``, ``KLD``, ``elbo`` and ``total``.

        ``elbo`` is ``rec + KLD``. ``total`` equals ``elbo`` unless the
        reconstruction loss reported ``sigma_est_rec``, which is then added.
        """
        # Reconstruction + KL divergence losses summed over all elements and batch
        losses, rec = self.reconstruction_loss(recon_x, x, mu, logvar, phase=phase)

        KLD = self.kl_divergence_unit(mu, logvar)

        losses.update(AttrDict(rec=rec, KLD=KLD, elbo=rec + KLD, total=rec + KLD))

        if 'sigma_est_rec' in losses:
            losses.total = losses.elbo + losses.sigma_est_rec

        return losses

    def kl_divergence_unit(self, mu, logvar):
        """Return ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))`` over all elements.

        See Appendix B of Kingma and Welling, Auto-Encoding Variational Bayes,
        ICLR 2014 (https://arxiv.org/abs/1312.6114).
        """
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


class ConstantSigmaDecoder(nn.Module):
    """Sigma decoder that returns one scalar log-sigma regardless of its inputs."""

    def __init__(self, sigma, learn_beta, device=None):
        """Create the scalar log-sigma parameter.

        Args:
            sigma: Initial standard deviation; its natural log is stored.
            learn_beta: Truthy to make the parameter trainable.
            device: Device for the parameter. When ``None``, CUDA is used if
                available and the CPU otherwise, so construction no longer
                fails on hosts without a GPU.
        """
        super().__init__()

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # 0-dim tensor in the default dtype, matching the original scalar.
        initial = torch.tensor(float(np.log(sigma)), device=device)
        self.log_sigma = nn.Parameter(initial, requires_grad=bool(learn_beta))

    def forward(self, *args):
        """Return ``AttrDict(sigma=log_sigma)``; all positional inputs are ignored."""
        return AttrDict(sigma=self.log_sigma)


class OptimalSigmaDecoder(nn.Module):
    """Sigma decoder that sets log-sigma to the log of the RMS reconstruction error."""

    def __init__(self, sigma, learn_beta, sigma_mode):
        """Remember ``sigma_mode``; ``sigma`` and ``learn_beta`` are accepted but unused."""
        super().__init__()
        self.sigma_mode = sigma_mode

    def forward(self, z, x, x_hat):
        """Compute ``log(sqrt(mean((x_hat - x) ** 2)))``.

        Args:
            z: Latent codes; not read.
            x: Target batch, or ``None`` when no target is available.
            x_hat: Reconstruction mean.

        Returns:
            ``AttrDict`` with ``sigma`` when ``x`` is given, otherwise an empty
            ``AttrDict``. For ``sigma_mode == 'optimal_constant'`` the mean is
            taken over every axis; for any other mode axis 0 is kept. Reduced
            axes are retained with size one.
        """
        outputs = AttrDict()
        if x is None:
            return outputs

        if x.shape != x_hat.shape and x.numel() == x_hat.numel():
            # Same normalisation the reconstruction loss applies: bring the
            # target to the reconstruction's shape (e.g. an image batch vs. a
            # flattened decoder output) before subtracting.
            x = x.reshape(x_hat.shape)

        reduce_dims = list(range(x_hat.dim()))
        if self.sigma_mode != 'optimal_constant':
            reduce_dims = reduce_dims[1:]

        mse = ((x_hat - x) ** 2).mean(reduce_dims, keepdim=True)
        outputs.sigma = mse.sqrt().log()
        return outputs


class ScalarSigmaDecoder(nn.Module):
    """Sigma decoder that predicts one log-sigma per latent code with a small MLP."""

    def __init__(self, *args, z_dim=20):
        """Build the MLP ``z_dim -> 64 -> 64 -> 1``.

        Args:
            *args: Accepted for signature compatibility with the other sigma
                decoders; ignored.
            z_dim: Width of the latent code fed to the network (default 20,
                the value that was previously hard-coded).
        """
        super().__init__()

        # Built once: earlier revisions constructed two smaller networks first
        # and immediately overwrote them.
        self.net = nn.Sequential(
            nn.Linear(z_dim, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, z, x=None, x_hat=None):
        """Return ``AttrDict(sigma=...)`` with the prediction for ``z``.

        Two trailing singleton axes are appended to the network output,
        presumably so it broadcasts against image-shaped means -- confirm
        against the caller. ``x`` and ``x_hat`` are ignored.
        """
        log_sigma = self.net(z)

        return AttrDict(sigma=log_sigma[..., None, None])


class PostHocSigmaDecoder(nn.Module):
    """Sigma decoder with a scalar log-sigma for training and an MLP estimate on the side."""

    def __init__(self, sigma, learn_beta, z_dim=20, device=None):
        """Build the estimator MLP and the scalar log-sigma parameter.

        Args:
            sigma: Initial standard deviation; its natural log is stored.
            learn_beta: Truthy to make the scalar parameter trainable.
            z_dim: Width of the latent code fed to the estimator (default 20,
                the value that was previously hard-coded).
            device: Device for the scalar parameter. When ``None``, CUDA is
                used if available and the CPU otherwise, so construction no
                longer fails on hosts without a GPU.
        """
        super().__init__()

        # Built once: earlier revisions constructed two smaller networks first
        # and immediately overwrote them.
        self.net = nn.Sequential(
            nn.Linear(z_dim, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 1),
        )

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # 0-dim tensor in the default dtype, matching the original scalar.
        initial = torch.tensor(float(np.log(sigma)), device=device)
        self.log_sigma = nn.Parameter(initial, requires_grad=bool(learn_beta))

    def forward(self, z, x=None, x_hat=None):
        """Return the sigma outputs for ``z``.

        In eval mode only the estimated log-sigma is returned as ``sigma``.
        In training mode ``sigma`` is the scalar parameter and the estimate is
        returned as ``sigma_est``. The estimator sees a detached ``z``, so it
        sends no gradient back into the latent code. ``x`` and ``x_hat`` are
        ignored.
        """
        log_sigma_est = self.net(z.detach())[..., None, None]

        if not self.training:
            # At test, return the estimated sigma
            return AttrDict(sigma=log_sigma_est)

        return AttrDict(sigma=self.log_sigma, sigma_est=log_sigma_est)


class PostHocOptimalSigmaDecoder(nn.Module):
    """Sigma decoder using the RMS-error log-sigma for training plus an MLP estimate."""

    def __init__(self, sigma, learn_beta, z_dim=20):
        """Build the estimator MLP ``z_dim -> 64 -> 64 -> 1``.

        Args:
            sigma: Accepted for signature compatibility; unused.
            learn_beta: Accepted for signature compatibility; unused.
            z_dim: Width of the latent code fed to the estimator (default 20,
                the value that was previously hard-coded).
        """
        super().__init__()

        # Built once: earlier revisions constructed two smaller networks first
        # and immediately overwrote them.
        self.net = nn.Sequential(
            nn.Linear(z_dim, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, z, x=None, x_hat=None):
        """Return the sigma outputs for ``z``.

        In eval mode, or when ``x`` is ``None``, only the estimated log-sigma
        is returned as ``sigma``. Otherwise ``sigma`` is the log of the RMS
        error between ``x`` and ``x_hat`` (mean over every axis except axis 0,
        reduced axes kept with size one) and the estimate is returned as
        ``sigma_est``. The estimator sees a detached ``z``.
        """
        log_sigma_est = self.net(z.detach())[..., None, None]

        if not self.training or x is None:
            # At test, return the estimated sigma
            return AttrDict(sigma=log_sigma_est)

        if x.shape != x_hat.shape and x.numel() == x_hat.numel():
            # Same normalisation the reconstruction loss applies: bring the
            # target to the reconstruction's shape before subtracting.
            x = x.reshape(x_hat.shape)

        reduce_dims = list(range(x.dim()))[1:]
        rms = ((x - x_hat) ** 2).mean(reduce_dims, keepdim=True).sqrt()

        return AttrDict(sigma=rms.log(), sigma_est=log_sigma_est)


class DummySigmaDecoder:
    """Placeholder sigma decoder that contributes no outputs."""

    def __init__(self, *args, **kwargs):
        """Accept and discard any construction arguments."""

    def __call__(self, *args, **kwargs):
        """Return a new empty dict, whatever the inputs are."""
        return dict()
